################################################################################
################################################################################
##### CROSS VALIDATION OF DYNAMIC GROUP-LEVEL ITEM-RESPONSE MODEL ##############
################################################################################
################################################################################

### (1) Loop over 10 simulations
### (2) Read in 10 different random training and validation partitions
### (3) Estimate pooled and separated model on training data
### (4) Compare out-of-sample performance on validation data
###
### N.B. This code takes several days to run and requires 10 parallel cores.
### If you wish to replicate our results only approximately, reduce the number
### of parallel chains or (if performance doesn't suffer) the number of
### iterations.

################################################################################
#### SET-UP CODE (LIBRARIES, FUNCTIONS, ETC.) ##################################
################################################################################

## All input files (CV-StanData<s>.RData) are read relative to this directory.
## Guarded so that re-running the script in the same R session, when the
## working directory already is CrossValidation, does not fail.
if (basename(getwd()) != "CrossValidation") setwd("CrossValidation")

### LIBRARIES
library(rstan)
library(foreign)
library(pbapply)
library(parallel)
library(TeachingDemos)
library(ggplot2)
library(reshape2)
library(car)
library(plyr)
library(dplyr)
library(survey)
library(stringr)
library(data.table)
library(matlab)
## Report how many cores are available; the model fits below run n.chain
## chains in parallel with mclapply(), one core per chain.
detectCores()

### FUNCTIONS
## Print the loop counter `index` whenever it is a multiple of `interval`
## (by default on every call); otherwise print nothing. Used to report the
## progress of long loops.
LoopProgress <- function(index, interval=1) {
    on.schedule <- (index %% interval) == 0
    if (!on.schedule) {
        return(invisible(NULL))
    }
    print(index)
}

## Survey years covered by the model (one Stan time period per year)
svy.yr.range <- 1976L:2010L

### STAN CODE
## Dynamic group-level IRT model; the model code is stored in `stan.code`
## (and echoed to the console by cat()). Data switches:
##   constant_item = 1  -> the thresholds kappa[1] are used in every year;
##   constant_item = 0  -> the difficulties diff_raw[t] follow a random walk
##                         and kappa[t] is used in year t (needs D >= T;
##                         presumably D == T -- TODO confirm in the data);
##   separate_years = 1 -> xi[t] and gamma[t] get independent priors each
##                         year (no over-time pooling of these parameters);
##   separate_years = 0 -> xi and gamma follow a dynamic linear model.
## delta_lag, delta_pred, sd_theta_bar and sd_theta follow random walks in
## both modes.
## NOTE(review): delta_pred_prior[t] for t > 1 appears in neither a prior nor
## the likelihood, so those elements have an improper flat posterior. They are
## saved via pars.to.save and so also enter the per-chain sd check that is
## used to drop chains; consider giving them a prior or declaring only one.
cat(stan.code <- '
// Stan code for dynamic group-level IRT model
data {
  int<lower=1> G; // number of covariate groups
  int<lower=1> Q; // number of items/questions
  int<lower=1> T; // number of years
  int<lower=1> N; // number of observed cells
  int<lower=1> S; // number of geographic units (e.g., states)
  int<lower=1> P; // number of hierarchical parameters, including geographic
  int<lower=1> H; // number of predictors for geographic unit effects
  int<lower=1> H_prior; // number of predictors for geographic unit effects (t=1)
  int<lower=1> D; // number of difficulty parameters per question
  int<lower=0,upper=1> constant_item; // indicator for constant item parameters
  int<lower=0,upper=1> separate_years; // indicator for no over-time smoothing
  int s_vec[N]; // long vector of responses
  int n_vec[N]; // long vector of counts
  int<lower=0> MMM[T, Q, G]; // missingness array
  matrix<lower=0, upper=1>[G, P] XX; // indicator matrix for hierarchical vars.
  row_vector[H] ZZ[T, S]; // data for geographic model
  row_vector[H_prior] ZZ_prior[T, S]; // data for geographic model
}
transformed data {
}
parameters {
  vector[Q] diff_raw[D]; // raw difficulty
  vector<lower=0>[Q] disc_raw; // discrimination
  vector[T] xi; // national mean (common intercept)
  vector[P] gamma[T]; // hierarchical parameters
  vector[T] delta_lag; // weight placed on geo. effects from prev. period
  vector[H] delta_pred[T]; // weight on geographic predictors
  vector[H_prior] delta_pred_prior[T]; // weight on geographic predictors (t=1)
  vector[G] theta_bar[T]; // group mean ability
  vector<lower=0>[T] sd_theta_bar; // sd of group ability means (by period)
  vector<lower=0>[T] sd_theta; // sd of abilities (by period)
  real<lower=0> sd_geo; // prior sd of geographic effects
  real<lower=0> sd_geo_prior; // prior sd of geographic effects (t=1)
  real<lower=0> sd_demo; // sd of demographic effecs
  real<lower=0> sd_innov_delta; // innovation sd of delta_pred and delta_lag
  real<lower=0> sd_innov_logsd; // innovation sd of sd_theta
  real<lower=0> sd_innov_gamma; // innovation sd of gamma, xi, and (opt.) diff
}
transformed parameters {
  vector[Q] diff[D]; // adjusted difficulty
  vector[Q] kappa[D]; // threshold
  vector<lower=0>[Q] disc; // normalized discrimination
  vector<lower=0>[Q] sd_item; // item standard deviation
  vector<lower=0>[Q] var_item; // item variance
  vector<lower=0>[T] var_theta; // variance of abilities
  vector[G] xb_theta_bar[T]; // linear predictor for group means
  vector[G] z[T, Q]; // array of vectors of group deviates
  real prob[T, Q, G]; // array of probabilities
  // Identify model by rescaling item parameters (Fox 2010, pp. 88-89)
  // scale (product = 1)
  disc <- disc_raw * pow(exp(sum(log(disc_raw))), (-inv(Q)));
  for (q in 1:Q) {
    sd_item[q] <- inv(disc[q]); // item standard deviations
  }
  for (d in 1:D) {
    // location (mean in first year = 0)
    diff[d] <- diff_raw[d] - mean(diff_raw[1]); 
    kappa[d] <- diff[d] ./ disc; // item thresholds
  }
  var_item <- sd_item .* sd_item; // item variances
  // Abilities
  var_theta <- sd_theta .* sd_theta; // within-group variances of abilities
  for (t in 1:T) { // loop over years 
    xb_theta_bar[t] <- xi[t] + XX * gamma[t]; // Gx1 = GxP * Px1
    for (q in 1:Q) { // loop over questions
      real denominator; //
      denominator <- sqrt(var_theta[t] + var_item[q]);
      // Group-level IRT model
      if (constant_item == 0) {
        z[t, q] <- (theta_bar[t] - kappa[t][q]) / denominator;
      }
      if (constant_item == 1) {
        z[t, q] <- (theta_bar[t] - kappa[1][q]) / denominator;
      }
      for (g in 1:G) { // loop over groups
        prob[t, q, g] <- Phi_approx(z[t, q, g]); // fast approx. of normal CDF
      } // end group loop
    } // end question loop
  } // end year loop
  // Convert counts and probabilities from array to vector
}
model {
  // TEMPORARY VARIABLES
  real prob_vec[N]; // long vector of probabilities (empty cells omitted)
  int pos;
  pos <- 0;
  // PRIORS
  if (constant_item == 1) {
    diff_raw[1] ~ normal(0, 1); // item difficulty (constant)
  }
  disc_raw ~ lognormal(0, 1); // item discrimination
  sd_geo ~ cauchy(0, 2.5); // sd of geographic effects
  sd_geo_prior ~ cauchy(0, 2.5); // prior sd of geographic effects
  sd_demo ~ cauchy(0, 2.5); // prior sd of demographic parameters
  sd_innov_delta ~ cauchy(0, 2.5); // innovation sd of delta_pred/delta_lag
  sd_innov_gamma ~ cauchy(0, 2.5); // innovation sd. of gamma, xi, and diff
  sd_innov_logsd ~ cauchy(0, 2.5); // innovation sd of theta_sd
  for (t in 1:T) { // loop over years
    if (separate_years == 1) { // Estimate model anew each period
      xi[t] ~ normal(0, 10); // intercept
      for (p in 1:P) { // Loop over individual predictors (gammas)
        if (p <= S) gamma[t][p] ~ normal(ZZ[t][p]*delta_pred[t], sd_geo);
        if (p > S) gamma[t][p] ~ normal(0, sd_demo);
      }
    }
    if (t == 1) {
      if (constant_item == 0) {
        diff_raw[t] ~ normal(0, 1); // item difficulty
      }
      // Priors for first period
      sd_theta_bar[t] ~ cauchy(0, 2.5);
      sd_theta[t] ~ cauchy(0, 2.5);
      delta_lag[t] ~ normal(0.5, 1);
      delta_pred[t] ~ normal(0, 10); 
      delta_pred_prior[t] ~ normal(0, 10); 
      if (separate_years == 0) {
        xi[t] ~ normal(0, 10); // intercept
        for (p in 1:P) { // Loop over individual predictors (gammas)
          if (p <= S) gamma[t][p] ~ normal(ZZ_prior[t][p]*delta_pred_prior[t],
                                           sd_geo_prior);
          if (p > S) gamma[t][p] ~ normal(0, sd_demo);
        }
      }
    }
    if (t > 1) {
      // TRANSITION MODEL
      // Difficulty parameters (if not constant)
      if (constant_item == 0) { 
        diff_raw[t] ~ normal(diff_raw[t - 1], sd_innov_gamma); //
      }
      // predictors in geographic models (random walk)
      delta_lag[t] ~ normal(delta_lag[t - 1], sd_innov_delta);
      delta_pred[t] ~ normal(delta_pred[t - 1], sd_innov_delta);
      sd_theta_bar[t] ~ lognormal(log(sd_theta_bar[t - 1]), sd_innov_logsd);
      sd_theta[t] ~ lognormal(log(sd_theta[t - 1]), sd_innov_logsd);
      if (separate_years == 0) {
        // Dynamic linear model for hierarchical parameters
        xi[t] ~ normal(xi[t - 1], sd_innov_gamma); // intercept
        for (p in 1:P) { // Loop over individual predictors (gammas)
          if (p <= S) {
            gamma[t][p] ~ normal(delta_lag[t]*gamma[t - 1][p] +
                                 ZZ[t][p]*delta_pred[t], sd_innov_gamma); // 
          }
          if (p > S) gamma[t][p] ~ normal(gamma[t - 1][p], sd_innov_gamma);
        }
      }
    }
    // RESPONSE MODEL
    // Model for group means 
    // (See "transformed parameters" for definition of xb_theta_bar)
    theta_bar[t] ~ normal(xb_theta_bar[t], sd_theta_bar[t]); // group means
    for (q in 1:Q) { // loop over questions
      for (g in 1:G) { // loop over groups
        if (MMM[t, q, g] == 0) { // Use only if not missing
          pos <- pos + 1;
          prob_vec[pos] <- prob[t, q, g];
        }
      } // end group loop
    } // end question loop
  } // end time loop
  // Model for group responses
  s_vec ~ binomial(n_vec, prob_vec); // fully vectorized
}
generated quantities {
  vector<lower=0>[T] sd_total; //
  for (t in 1:T) {
    sd_total[t] <- sqrt(variance(theta_bar[t]) + square(sd_theta[t]));
  }
}')

## Parameters kept in the fitted stanfit objects (all others are discarded).
pars.to.save <- c("theta_bar", "xi", "gamma", "delta_lag", "delta_pred",
                  "delta_pred_prior", "kappa", "sd_item", "sd_theta",
                  "sd_theta_bar", "sd_demo", "sd_geo", "sd_innov_gamma",
                  "sd_innov_delta", "sd_innov_logsd", "sd_total")

## MCMC settings. The full run uses 4000 iterations per chain (2000 of them
## warm-up) on 10 parallel chains. Set quick.run <- TRUE for a fast smoke test
## (iterations / 100, chains / 5); its estimates are NOT usable for inference.
quick.run <- FALSE
(n.iter <- if (quick.run) 4e3/100 else 4e3)
(n.chain <- if (quick.run) 10/5 else 10) ## number of parallel chains
(max.save <- 1e3) ## approximate number of draws saved over all chains
(n.warm <- if (quick.run) 2e3/100 else 2e3)
## Thin so that roughly max.save post-warm-up draws are saved in total
(n.thin <- ceiling((n.iter - n.warm) / (max.save / n.chain)))

## Share of data in the training partition. NOTE(review): not used below;
## the partitions are read pre-built from CV-StanData<s>.RData.
training.fraction <- 1/4
nSims <- 10 ## number of cross-validation replicates
s <- 1 ## first replicate to run (set > 1 to resume an interrupted run)

################################################################################
#### HELPER FUNCTIONS ##########################################################
################################################################################

## Compile the Stan model once. Every chain of every simulation reuses this
## compiled model; calling stan(model_code = ...) inside each forked chain
## would recompile the C++ code 2 * n.chain times per simulation.
stan.model <- stan_model(model_code=stan.code)

## Run n.chain single-chain samplers in parallel (one forked process per
## chain via mclapply) on `data` and return the list of results. An element
## may be a "try-error" object instead of a stanfit if its chain errored.
RunChains <- function(data) {
    mclapply(seq_len(n.chain), mc.cores=n.chain, FUN=function(chain) {
        sampling(stan.model, data=data, iter=n.iter, chains=1,
                 warmup=n.warm, thin=n.thin, verbose=FALSE,
                 chain_id=chain, refresh=max(floor(n.iter/100), 1),
                 pars=pars.to.save, seed=chain)
    })
}

## Merge the usable chains in `fits` (a list returned by RunChains()).
## A chain is dropped if it failed (not a stanfit, or no draws) or looks stuck
## (median posterior sd more than 2 sd below that of the other chains).
## Returns list(fit = merged stanfit, draws = iterations x parameters matrix
## of post-warm-up draws, thinned to roughly 2000 draws over all chains).
CombineChains <- function(fits) {
    ok <- vapply(fits,
                 function(x) inherits(x, "stanfit") && length(x@sim) > 0,
                 logical(1))
    cat('Failed chains:', sum(!ok), '\n')
    fits <- fits[ok]
    if (length(fits) == 0) stop("All chains failed.", call.=FALSE)
    ## check if any chains have unusually low standard deviations (not
    ## mixing); scale() is NaN for a single chain, which must then be kept
    chain.sds <- vapply(fits, function(X) median(apply(X, 2:3, sd)),
                        numeric(1))
    chain.z <- as.vector(scale(chain.sds))
    to.drop <- !is.na(chain.z) & chain.z < -2
    cat('Standardized chain sds:', round(chain.z, 2), '\n')
    cat('Non-mixing chains dropped:', sum(to.drop), '\n')
    fit <- sflist2stanfit(fits[!to.drop])
    arr <- rstan::extract(fit, permuted=FALSE, inc_warmup=FALSE)
    post.thin <- max(1, floor(dim(arr)[[1]] * dim(arr)[[2]] / 2000))
    sub.idx <- seq(1, dim(arr)[[1]], post.thin)
    ## Stack the chains: (iterations x chains x parameters) -> matrix
    draws <- do.call(rbind, lapply(seq_len(dim(arr)[[2]]), function(i) {
        matrix(arr[sub.idx, i, ], nrow=length(sub.idx),
               dimnames=list(NULL, dimnames(arr)[[3]]))
    }))
    names(dimnames(draws)) <- c("iterations", "parameters")
    list(fit=fit, draws=draws)
}

## Posterior-mean predicted proportion for every cell of `train.props`
## (rows: "group.year" cells, columns: questions), averaging over the draws:
##   Pr = mean( Phi((theta_bar - kappa_q) / sqrt(sd_theta^2 + sd_item_q^2)) ).
## `group` is the factor whose levels label the G groups. Cells that are
## missing in `val.props` are set to NA.
PredictProps <- function(draws, train.props, val.props, group, G) {
    n.yrs <- length(svy.yr.range)
    par.names <- colnames(draws)
    ## Item parameters: one column per question, in train.props column order
    kappa.mx <- draws[, grepl("kappa\\[", par.names), drop=FALSE]
    colnames(kappa.mx) <- colnames(train.props)
    var_item.mx <- draws[, grepl("sd_item\\[", par.names), drop=FALSE]^2
    colnames(var_item.mx) <- colnames(train.props)
    ## Stan flattens theta_bar[t, g] with the year index varying fastest:
    ## all years of group 1, then all years of group 2, and so on.
    ## (Never use the bare name `T` here: in R it is TRUE, not a year count.)
    theta_bar.mx <- draws[, grepl("^theta_bar\\[", par.names), drop=FALSE]
    sd_theta.mx <- draws[, grepl("^sd_theta\\[", par.names), drop=FALSE]
    stopifnot(nlevels(group) == G, ncol(theta_bar.mx) == G * n.yrs,
              ncol(sd_theta.mx) == n.yrs)
    cell.names <- paste(rep(levels(group), each=n.yrs),
                        rep(svy.yr.range, G), sep=".")
    NN.order <- match(rownames(train.props), cell.names)
    if (anyNA(NN.order)) {
        warning(sum(is.na(NN.order)),
                " rows of train.props match no group.year cell",
                call.=FALSE)
    }
    theta_bar.mx <- theta_bar.mx[, NN.order, drop=FALSE]
    ## sd_theta[t] is shared by all groups in year t
    var_theta.mx <-
        sd_theta.mx[, rep(seq_len(n.yrs), G)[NN.order], drop=FALSE]^2
    pred <- array(NA, dim=dim(train.props), dimnames=dimnames(train.props))
    for (q in seq_len(ncol(pred))) {
        LoopProgress(q, 10)
        ## kappa/var_item vectors have one entry per draw (row) and recycle
        ## down the columns of the draws x cells matrices
        zmat <- ((theta_bar.mx - kappa.mx[, q]) /
                 sqrt(var_theta.mx + var_item.mx[, q]))
        pred[, q] <- colMeans(pnorm(zmat))
    }
    is.na(pred) <- is.na(val.props)
    pred
}

## Agreement between observed proportions `obs` and predictions `pred`
## (matrices of the same shape; NA cells ignored): correlation, slope of the
## regression of obs on pred, bias (mean error), MAE and root MSE.
FitStats <- function(obs, pred) {
    err <- obs - pred
    c(cor=cor(as.vector(obs), as.vector(pred), use='complete.obs'),
      slope=coef(lm(as.vector(obs) ~ as.vector(pred)))[[2]],
      bias=mean(err, na.rm=TRUE),
      mae=mean(abs(err), na.rm=TRUE),
      rmse=sqrt(mean(err^2, na.rm=TRUE)))
}

################################################################################
#### LOOP OVER SIMULATIONS #####################################################
################################################################################

## Out-of-sample fit statistics, one data frame per simulation
cv.results <- vector("list", nSims)

for (s in s:nSims) {
    cat('\nSIMULATION', s, 'OF', nSims, '\n')

    ## Presumably loads stan.dyn.data, train.props, val.props and group (the
    ## objects removed at the end of each iteration) -- verify against files
    load(paste0('CV-StanData', s, '.RData'))
    G <- stan.dyn.data$G
    ## Year labels come from svy.yr.range, so it must match the model's T
    stopifnot(stan.dyn.data$T == length(svy.yr.range))

    ## ESTIMATE MODEL WITH OVER-TIME POOLING
    cat('\nPooling over time\n')
    print(date())
    print(system.time(stan.dyn.par <- RunChains(stan.dyn.data)))
    print(date())
    dyn <- CombineChains(stan.dyn.par)
    stan.dyn.cmb <- dyn$fit
    pars.dyn.mx <- dyn$draws

    print(stan.dyn.cmb, pars="lp__")
    print(stan.dyn.cmb, probs=c(0.05, 0.5, 0.95), digits=2, pars="theta_bar")
    print(stan.dyn.cmb, probs=c(0.05, 0.5, 0.95), digits=2, pars="kappa")
    print(stan.dyn.cmb, probs=c(0.05, 0.5, 0.95), digits=2, pars="sd_item")

    pred.dyn.props <- PredictProps(pars.dyn.mx, train.props, val.props,
                                   group, G)

    ## ESTIMATE MODEL WITH NO OVER-TIME POOLING
    cat('\nNo pooling over time\n')
    stan.sep.data <- stan.dyn.data
    stan.sep.data$separate_years <- 1L
    print(date())
    print(system.time(stan.sep.par <- RunChains(stan.sep.data)))
    print(date())
    sep <- CombineChains(stan.sep.par)
    stan.sep.cmb <- sep$fit
    pars.sep.mx <- sep$draws

    print(stan.sep.cmb, pars="lp__")
    print(stan.sep.cmb, probs=c(0.05, 0.5, 0.95), digits=2, pars="theta_bar")
    print(stan.sep.cmb, probs=c(0.05, 0.5, 0.95), digits=2, pars="kappa")
    print(stan.sep.cmb, probs=c(0.05, 0.5, 0.95), digits=2, pars="sd_item")

    pred.sep.props <- PredictProps(pars.sep.mx, train.props, val.props,
                                   group, G)

    ## OUT-OF-SAMPLE VALIDATION OF BOTH MODELS
    cat('\nVALIDATION\n')
    train.fit <- rbind(dyn=FitStats(train.props, pred.dyn.props),
                       sep=FitStats(train.props, pred.sep.props))
    val.fit <- rbind(dyn=FitStats(val.props, pred.dyn.props),
                     sep=FitStats(val.props, pred.sep.props))
    cat('\nTraining (in-sample) fit\n')
    print(train.fit)
    cat('\nValidation (out-of-sample) fit\n')
    print(val.fit)

    bias.dyn <- val.fit["dyn", "bias"]
    mae.dyn <- val.fit["dyn", "mae"]
    rmse.dyn <- val.fit["dyn", "rmse"]
    bias.sep <- val.fit["sep", "bias"]
    mae.sep <- val.fit["sep", "mae"]
    rmse.sep <- val.fit["sep", "rmse"]

    ## Comparison (negative values favor the over-time pooling model)
    cat('\nOUT-OF-SAMPLE COMPARISON FOR SIMULATION', s, '\n')
    print(abs(bias.dyn) - abs(bias.sep))
    print(mae.dyn - mae.sep)
    print(rmse.dyn - rmse.sep)

    cv.results[[s]] <- data.frame(sim=s, model=rownames(val.fit), val.fit,
                                  row.names=NULL)

    rm(list = c("stan.dyn.data", "train.props", "group", "val.props"))
}

## Out-of-sample statistics for all simulations run in this session
cv.summary <- do.call(rbind, cv.results)
print(cv.summary)
